import os
from statistics import stdev
import pickle
import numpy as np
from MCTS_utils import eval_agent, sample_child
import MCTS_utils

# Scaling constant that multiplies both terms of the 'ucb-v' exploration
# bound computed by Node.
eps = 0.1
# Selects which exploration-bound formula Node uses; the code recognises
# 'ucb-v' and 'ucb1-tuned'.
mode = 'ucb-v'

def get_num_total_evals():
    """Return the number of utility measurements summed over the whole tree.

    The tree is walked from the node registered under key 0 in
    ``MCTS_utils.nodes``.
    """
    root = MCTS_utils.nodes[0]

    def count_evals(node):
        return node.num_evals

    return root.get_sum(count_evals)

class Node:
    """A tree node holding a commit id and the utility samples measured for it.

    On construction every node registers itself in ``MCTS_utils.nodes`` under
    its ``id``. Children are kept as direct object references in ``children``;
    the parent is referenced only by ``parent_id``.
    """

    def __init__(self,
                 commit_id,
                 utility_measures=None,
                 parent_id=None,
                 id=None):
        """Create the node and register it in ``MCTS_utils.nodes``.

        Args:
            commit_id: Identifier of the commit this node stands for. The
                literal string ``'failed'`` marks a failed node.
            utility_measures: Optional initial utility samples. They are
                copied, so the caller's sequence is never mutated.
            parent_id: ``id`` of the parent node, or ``None`` for the root.
            id: Explicit node id. When ``None`` the current size of
                ``MCTS_utils.nodes`` is used. (The name shadows the builtin
                but is kept for backward compatibility with callers.)
        """
        self.commit_id = commit_id
        self.children = []
        # Copy: ``measure`` grows this list in place, which must not leak
        # into a list owned by the caller.
        self.utility_measures = list(utility_measures) if utility_measures else []
        self.parent_id = parent_id
        self.id = len(MCTS_utils.nodes) if id is None else id
        MCTS_utils.nodes[self.id] = self

    def measure(self, pending_tasks, n=1, random_level=.5):
        """Evaluate this node's commit and append the results to ``utility_measures``.

        All arguments are forwarded to ``eval_agent`` (``n`` as ``num_tasks``);
        whatever it returns is appended to the node's samples.
        """
        self.utility_measures.extend(
            eval_agent(
                self.commit_id,
                num_tasks=n,
                pending_tasks=pending_tasks,
                random_level=random_level,
                init_agent_path=os.path.join(MCTS_utils.output_dir, 'initial/src'),
            )
        )

    def expand(self, use_own_docker_image=False):
        """Try to create one child node via ``sample_child``.

        Returns:
            ``True`` if a child was created and attached, ``False`` if
            ``sample_child`` reported ``'failed'``.

        Raises:
            RuntimeError: If this node itself is a failed node.
        """
        if self.commit_id == 'failed':
            raise RuntimeError("Cannot expand a failed node.")
        child_commit_id = sample_child(self.commit_id, max_try=2, use_own_docker_image=use_own_docker_image)
        if child_commit_id == 'failed':
            return False
        self.children.append(Node(commit_id=child_commit_id, parent_id=self.id))
        return True

    def get_best(self, metric):
        """Return the node of this subtree (self included) that maximises ``metric``."""
        if not self.children:
            return self
        # Children's bests come first so that ties resolve exactly as before
        # (``max`` keeps the first maximal element).
        candidates = [child.get_best(metric) for child in self.children]
        candidates.append(self)
        return max(candidates, key=metric)

    def get_sum(self, metric):
        """Return ``metric`` summed over this node and all of its descendants."""
        if not self.children:
            return metric(self)
        return sum(child.get_sum(metric) for child in self.children) + metric(self)

    def get_sub_tree(self, fn=lambda self: self):
        """Return ``fn(node)`` for every node of this subtree in pre-order (self first)."""
        collected = [fn(self)]
        for child in self.children:
            collected.extend(child.get_sub_tree(fn))
        return collected

    @staticmethod
    def _exploration_bound(evals):
        """Return the exploration bonus for the non-empty sample ``evals``.

        The formula is selected by the module-level ``mode``.

        Raises:
            ValueError: If ``mode`` is not one of the supported names
                (previously this silently produced ``None``).
        """
        count = len(evals)
        if mode == 'ucb1-tuned':
            # The total is a full tree traversal, so compute its log only once.
            log_total = np.log(get_num_total_evals())
            variance_bound = min(1/4, np.var(evals) + (2 * log_total / count) ** 0.5)
            return (log_total / count * variance_bound) ** 0.5
        if mode == 'ucb-v':
            return (2 * np.var(evals) * eps / count) ** 0.5 + 3 * eps / count
        raise ValueError(f"Unsupported exploration mode: {mode!r}")

    @property
    def utility_exploration_bound(self):
        """Exploration bonus based on this node's own samples.

        ``inf`` when the node has no samples, ``-inf`` for a failed node.
        """
        # NOTE(review): the "no samples" check runs before the "failed" check,
        # so a failed node without samples yields +inf here, whereas the
        # descendants_* variants test "failed" first. Order kept as it was.
        if self.num_evals == 0:
            return np.inf
        if self.commit_id == 'failed':
            return -np.inf
        return self._exploration_bound(self.utility_measures)

    def utility_ucb(self):
        """Mean utility plus exploration bonus, from this node's own samples."""
        if self.num_evals == 0:
            return np.inf
        if self.commit_id == 'failed':
            return -np.inf
        return self.mean_utility + self.utility_exploration_bound

    def get_pseudo_decendant_evals(self, num_pseudo):
        """Return this node's contribution to its descendant sample pool.

        With fewer than ``num_pseudo`` samples the raw samples are used;
        otherwise the node contributes ``num_pseudo`` copies of its mean.
        A fresh list is always returned so callers may extend it freely.
        """
        if self.num_evals < num_pseudo:
            # Copy: returning the stored list itself let callers append
            # descendants' samples straight into ``utility_measures``.
            return list(self.utility_measures)
        return [self.mean_utility] * num_pseudo

    def get_decendant_evals(self, num_pseudo=10):
        """Return the pooled samples of this node and all of its descendants.

        The node's own part comes from ``get_pseudo_decendant_evals``; every
        descendant contributes its full ``utility_measures``. No node's stored
        samples are modified.
        """
        pooled = list(self.get_pseudo_decendant_evals(num_pseudo))
        for descendant in self.get_sub_tree()[1:]:
            pooled.extend(descendant.utility_measures)
        return pooled

    @property
    def descendants_exploration_bound(self):
        """Exploration bonus based on the pooled subtree samples.

        ``-inf`` for a failed node, ``inf`` when the pool is empty.
        """
        if self.commit_id == 'failed':
            return -np.inf
        pooled = self.get_decendant_evals()
        if len(pooled) == 0:
            return np.inf
        return self._exploration_bound(pooled)

    def descendants_ucb(self):
        """Mean plus exploration bonus, both from the pooled subtree samples."""
        if self.commit_id == 'failed':
            return -np.inf
        pooled = self.get_decendant_evals()
        if len(pooled) == 0:
            return np.inf
        # Reuse the pool already built instead of collecting it a second time.
        return np.mean(pooled) + self._exploration_bound(pooled)

    def get_value_estimate(self):
        """Compute ``value_estimate`` and ``std_estimate`` and store them on the node.

        The root (``parent_id is None``) gets the fixed values 0.2 and 0.
        Any other node pools the clipped mean utilities of its children that
        have at least one sample; if there are none, nothing is stored.
        Always returns ``None``.
        """
        if self.parent_id is None:
            self.value_estimate = .2
            self.std_estimate = 0
            return
        # Children without samples have an infinite mean and would cause a
        # division by zero below, so they are left out.
        evaluated = [child for child in self.children if child.num_evals > 0]
        if not evaluated:
            return
        bar_Y_is = np.array([max(0, child.mean_utility) for child in evaluated])
        # Per-child sample std; falls back to 1/2 when it cannot be estimated.
        s_eps = np.array([
            stdev(child.utility_measures)
            if child.num_evals > 1 and not np.isnan(child.utility_measures).all()
            else 1 / 2
            for child in evaluated
        ])
        n_is = np.array([child.num_evals for child in evaluated])
        v_is = s_eps ** 2 / n_is
        sum_v = np.sum(v_is)
        # NOTE(review): the next three lines resemble the DerSimonian-Laird
        # between-group variance estimator, but weight by v_i rather than
        # 1/v_i and group the denominator differently. The arithmetic is kept
        # exactly as it was - confirm the intended formula. When every child
        # has zero sample variance, sum_v is 0 and the result is NaN.
        bar_Y_w = np.sum(v_is * bar_Y_is) / sum_v
        Q = np.sum(v_is * (bar_Y_is - bar_Y_w) ** 2)
        hat_var_X = max(0, (Q - len(evaluated)) / sum_v - np.sum(v_is ** 2) / sum_v)
        w_is = 1 / (v_is + hat_var_X)
        # Previously these results were computed and then discarded.
        self.value_estimate = np.sum(w_is * bar_Y_is) / np.sum(w_is)
        self.std_estimate = 1 / np.sum(w_is) ** 0.5

    @property
    def std_utility(self):
        """Sample standard deviation of the node's samples (0.5 with fewer than two)."""
        if self.num_evals <= 1:
            return .5
        return stdev(self.utility_measures)

    @property
    def sum_utility(self):
        """Sum of the node's utility samples."""
        return sum(self.utility_measures)

    @property
    def num_evals(self):
        """Number of utility samples recorded for this node."""
        return len(self.utility_measures)

    @property
    def mean_utility(self):
        """Mean of the node's utility samples, or ``inf`` when there are none."""
        if self.num_evals == 0:
            return np.inf
        return self.sum_utility / self.num_evals

    def add_child(self, child):
        """Attach an already constructed node as a child."""
        self.children.append(child)

    def save_as_dict(self):
        """Return a plain-dict summary of the node (ids and aggregate statistics only)."""
        return {
            'commit_id': self.commit_id,
            'id': self.id,
            'parent_id': self.parent_id,
            'mean_utility': self.mean_utility,
            'num_evals': self.num_evals,
        }


def is_pickleable(obj):
    """Report whether ``obj`` can be serialised with ``pickle.dumps``.

    Any exception raised while pickling is printed and reported as ``False``.
    """
    try:
        pickle.dumps(obj)
    except Exception as e:
        print(f"Pickle error: {e}")
        return False
    else:
        return True
    

# if __name__ == "__main__":
#     print(is_pickleable(Node('alskdjf230r43')))